// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bindinfo

import (
	"container/list"
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/pingcap/tidb/pkg/parser"
	"github.com/pingcap/tidb/pkg/parser/ast"
	"github.com/pingcap/tidb/pkg/planner/util/fixcontrol"
	"github.com/pingcap/tidb/pkg/sessionctx"
	"github.com/pingcap/tidb/pkg/sessionctx/vardef"
	"github.com/pingcap/tidb/pkg/util"
	"github.com/pingcap/tidb/pkg/util/hint"
)

// PlanGenerator is used to generate new Plan Candidates for this specified query.
type PlanGenerator interface {
	// Generate returns candidate bindings for sql, which is parsed with the given charset and
	// collation; defaultSchema is the schema unqualified table names are resolved against.
	Generate(defaultSchema, sql, charset, collation string) (plans []*BindingPlanInfo, err error)
}

// planGenerator implements PlanGenerator.
// It generates new plans by adjusting optimizer variables, fix-controls and a leading-2
// join-order hint, running the optimizer on a session obtained through sPool.
type planGenerator struct {
	sPool util.DestroyableSessionPool
}

// Compile-time check that planGenerator satisfies PlanGenerator.
var _ PlanGenerator = (*planGenerator)(nil)

// Generate generates new plans for the given SQL statement.
//
// Only statements whose trimmed text starts with SELECT (case-insensitive) are supported; for
// anything else it returns (nil, nil). For every distinct plan found it returns a Binding whose
// BindSQL is sql with the plan's reproducing hints injected right after the SELECT keyword,
// marked with Source "generated". On error the returned slice is always nil.
func (g *planGenerator) Generate(defaultSchema, sql, charset, collation string) (plans []*BindingPlanInfo, err error) {
	// TODO: only support SQL starting with SELECT for now, support other types of SQLs later.
	// TODO: make this check more strict.
	sql = strings.TrimSpace(sql)
	const prefix = "SELECT"
	if len(sql) < len(prefix) || !strings.EqualFold(sql[:len(prefix)], prefix) {
		return nil, nil // not a SELECT statement
	}

	// Collect into a local so a failure half-way through never leaks a partial result.
	var generated []*BindingPlanInfo
	err = callWithSCtx(g.sPool, false, func(sctx sessionctx.Context) error {
		genedPlans, genErr := generatePlanWithSCtx(sctx, defaultSchema, sql, charset, collation)
		if genErr != nil {
			return genErr
		}
		generated = make([]*BindingPlanInfo, 0, len(genedPlans))

		for _, gp := range genedPlans {
			// TODO: construct bindingSQL in a more strict way.
			bindingSQL := sql[:len(prefix)] + " /*+ " + gp.planHints + " */ " + sql[len(prefix):]
			binding := &Binding{
				OriginalSQL: sql,
				BindSQL:     bindingSQL,
				Db:          defaultSchema,
				Source:      "generated",
				PlanDigest:  gp.planDigest,
			}
			if hintErr := prepareHints(sctx, binding); hintErr != nil {
				return hintErr
			}
			generated = append(generated, &BindingPlanInfo{
				Binding: binding,
				Plan:    gp.PlanText(),
			})
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return generated, nil
}

// tableName identifies a table by its schema and table name.
type tableName struct {
	schema string
	name   string
}

// String renders the table as "schema.name"; this form is also used as a dedup/sort key.
func (t *tableName) String() string {
	// All operands are strings, so Sprint inserts no separators of its own.
	return fmt.Sprint(t.schema, ".", t.name)
}

// genedPlan represents a plan generated by planGenerator.
type genedPlan struct {
	planDigest string     // digest of this plan
	planHints  string     // a set of hints to reproduce this plan
	planText   [][]string // human-readable plan text, one slice of columns per row
}

// PlanText renders planText as one string: columns within a row are tab-separated and rows are
// newline-separated, with no trailing newline.
func (gp *genedPlan) PlanText() string {
	rendered := make([]string, 0, len(gp.planText))
	for _, cols := range gp.planText {
		rendered = append(rendered, strings.Join(cols, "\t"))
	}
	return strings.Join(rendered, "\n")
}

// state represents a state of the optimizer variables and fixes.
// varNames/varValues and fixIDs/fixValues are parallel slices (same index, same entry).
type state struct {
	leading2  [2]*tableName // leading-2 table names; a LEADING hint is only added when both are non-nil
	varNames  []string      // relevant variables and their values to generate a certain plan
	varValues []any
	fixIDs    []uint64 // relevant fixes and their values to generate a certain plan
	fixValues []string
}

// Encode encodes the state into a comma-separated string used as its dedup key: the non-nil
// leading-2 tables, then the variable values (float64 values rounded to 4 decimals), then the
// fix-control values.
func (s *state) Encode() string {
	var sb strings.Builder
	// sep emits a comma unless nothing has been written yet.
	sep := func() {
		if sb.Len() > 0 {
			sb.WriteByte(',')
		}
	}

	for _, t := range s.leading2 {
		if t != nil {
			sep()
			sb.WriteString(t.String())
		}
	}
	for _, v := range s.varValues {
		sep()
		if f, isFloat := v.(float64); isFloat {
			// only consider 4 decimal digits, which should be enough for optimizer tuning.
			fmt.Fprintf(&sb, "%.4f", f)
		} else {
			fmt.Fprintf(&sb, "%v", v)
		}
	}
	for _, v := range s.fixValues {
		sep()
		sb.WriteString(v)
	}
	return sb.String()
}

// newStateWithLeading2 returns a new state identical to old except for its leading-2 pair.
// The variable and fix-control slices are shared with old rather than copied, so mutating
// them in place would affect both states.
func newStateWithLeading2(old *state, leading2 [2]*tableName) *state {
	derived := *old
	derived.leading2 = leading2
	return &derived
}

// newStateWithNewVar returns a new state identical to old except that the first variable named
// varName takes varVal. varValues is copied so old is left untouched; every other slice is
// shared with old. If varName is not among old.varNames, the values are copied unchanged.
func newStateWithNewVar(old *state, varName string, varVal any) *state {
	derived := *old
	derived.varValues = make([]any, len(old.varValues))
	copy(derived.varValues, old.varValues)
	for idx, name := range derived.varNames {
		if name != varName {
			continue
		}
		derived.varValues[idx] = varVal
		break
	}
	return &derived
}

// newStateWithNewFix returns a new state identical to old except that the first fix-control
// with ID fixID takes fixVal. fixValues is copied so old is left untouched; every other slice is
// shared with old. If fixID is not among old.fixIDs, the values are copied unchanged.
func newStateWithNewFix(old *state, fixID uint64, fixVal string) *state {
	derived := *old
	derived.fixValues = make([]string, len(old.fixValues))
	copy(derived.fixValues, old.fixValues)
	for idx, id := range derived.fixIDs {
		if id != fixID {
			continue
		}
		derived.fixValues[idx] = fixVal
		break
	}
	return &derived
}

// generatePlanWithSCtx parses sql and searches for alternative plans for it on sctx.
//
// Side effects on sctx: CurrentDB is set to defaultSchema and CostModelVersion to 2, and the
// search itself overwrites optimizer variables and fix-controls (see genPlanUnderState).
// Every ordered pair of distinct tables referenced by the statement is offered to the search as
// a candidate LEADING(t1, t2) hint.
func generatePlanWithSCtx(sctx sessionctx.Context, defaultSchema, sql, charset, collation string) (plans []*genedPlan, err error) {
	p := parser.New()
	stmt, err := p.ParseOneStmt(sql, charset, collation)
	if err != nil {
		return nil, err
	}
	sessVars := sctx.GetSessionVars()
	sessVars.CurrentDB = defaultSchema
	sessVars.CostModelVersion = 2 // cost factor only works on cost-model v2
	vars, fixes, err := RecordRelevantOptVarsAndFixes(sctx, stmt)
	if err != nil {
		return nil, err
	}

	tableNames := extractSelectTableNames(defaultSchema, stmt)
	// There are exactly n*(n-1) ordered pairs of distinct tables; size the slice for them.
	n := len(tableNames)
	possibleLeading2 := make([][2]*tableName, 0, n*(n-1))
	for i, first := range tableNames {
		for j, second := range tableNames {
			if i != j {
				possibleLeading2 = append(possibleLeading2, [2]*tableName{first, second})
			}
		}
	}
	return breadthFirstPlanSearch(sctx, stmt, vars, fixes, possibleLeading2)
}

func breadthFirstPlanSearch(sctx sessionctx.Context, stmt ast.StmtNode,
	vars []string, fixes []uint64, possibleLeading2 [][2]*tableName) (plans []*genedPlan, err error) {
	// init BFS structures
	visitedStates := make(map[string]struct{})  // map[encodedState]struct{}, all visited states
	visitedPlans := make(map[string]*genedPlan) // map[planDigest]plan, all visited plans
	stateList := list.New()                     // states in queue to explore

	// init the start state and push it into the BFS list
	// start state: no specified leading hint + default values of all variables and fix-controls
	startState, err := getStartState(vars, fixes)
	if err != nil {
		return nil, err
	}
	visitedStates[startState.Encode()] = struct{}{}
	stateList.PushBack(startState)

	maxPlans, maxExploreState := 30, 5000
	for len(visitedPlans) < maxPlans && len(visitedStates) < maxExploreState && stateList.Len() > 0 {
		currState := stateList.Remove(stateList.Front()).(*state)
		plan, err := genPlanUnderState(sctx, stmt, currState)
		if err != nil {
			return nil, err
		}
		visitedPlans[plan.planDigest] = plan

		// in each step, adjust one variable or fix or join-order
		for _, leading2 := range possibleLeading2 {
			newState := newStateWithLeading2(currState, leading2)
			if _, ok := visitedStates[newState.Encode()]; !ok {
				visitedStates[newState.Encode()] = struct{}{}
				stateList.PushBack(newState)
			}
		}
		for i := range vars {
			varName, varVal := vars[i], currState.varValues[i]
			newVarVal, err := adjustVar(varName, varVal)
			if err != nil {
				return nil, err
			}
			newState := newStateWithNewVar(currState, varName, newVarVal)
			if _, ok := visitedStates[newState.Encode()]; !ok {
				visitedStates[newState.Encode()] = struct{}{}
				stateList.PushBack(newState)
			}
		}
		for i := range fixes {
			fixID, fixVal := fixes[i], currState.fixValues[i]
			newFixVal, err := adjustFix(fixID, fixVal)
			if err != nil {
				return nil, err
			}
			newState := newStateWithNewFix(currState, fixID, newFixVal)
			if _, ok := visitedStates[newState.Encode()]; !ok {
				visitedStates[newState.Encode()] = struct{}{}
				stateList.PushBack(newState)
			}
		}
	}

	plans = make([]*genedPlan, 0, len(visitedPlans))
	for _, plan := range visitedPlans {
		plans = append(plans, plan)
	}
	sort.Slice(plans, func(i, j int) bool { // to make the result stable
		return plans[i].planDigest < plans[j].planDigest
	})
	return plans, nil
}

// genPlanUnderState returns a plan generated under the given state (vars and fix-controls).
func genPlanUnderState(sctx sessionctx.Context, stmt ast.StmtNode, state *state) (plan *genedPlan, err error) {
	for i, varName := range state.varNames {
		switch varName {
		case vardef.TiDBOptIndexScanCostFactor:
			sctx.GetSessionVars().IndexScanCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptIndexReaderCostFactor:
			sctx.GetSessionVars().IndexReaderCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptTableReaderCostFactor:
			sctx.GetSessionVars().TableReaderCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptTableFullScanCostFactor:
			sctx.GetSessionVars().TableFullScanCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptTableRangeScanCostFactor:
			sctx.GetSessionVars().TableRangeScanCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptTableRowIDScanCostFactor:
			sctx.GetSessionVars().TableRowIDScanCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptTableTiFlashScanCostFactor:
			sctx.GetSessionVars().TableTiFlashScanCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptIndexLookupCostFactor:
			sctx.GetSessionVars().IndexLookupCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptIndexMergeCostFactor:
			sctx.GetSessionVars().IndexMergeCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptSortCostFactor:
			sctx.GetSessionVars().SortCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptTopNCostFactor:
			sctx.GetSessionVars().TopNCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptLimitCostFactor:
			sctx.GetSessionVars().LimitCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptStreamAggCostFactor:
			sctx.GetSessionVars().StreamAggCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptHashAggCostFactor:
			sctx.GetSessionVars().HashAggCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptMergeJoinCostFactor:
			sctx.GetSessionVars().MergeJoinCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptHashJoinCostFactor:
			sctx.GetSessionVars().HashJoinCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptIndexJoinCostFactor:
			sctx.GetSessionVars().IndexJoinCostFactor = state.varValues[i].(float64)
		case vardef.TiDBOptOrderingIdxSelRatio:
			sctx.GetSessionVars().OptOrderingIdxSelRatio = state.varValues[i].(float64)
		case vardef.TiDBOptRiskEqSkewRatio:
			sctx.GetSessionVars().RiskEqSkewRatio = state.varValues[i].(float64)
		case vardef.TiDBOptPreferRangeScan:
			sctx.GetSessionVars().SetAllowPreferRangeScan(state.varValues[i].(bool))
		default:
			return nil, fmt.Errorf("unsupported variable %s in plan generation", varName)
		}
	}

	fixControlStrBuilder := strings.Builder{}
	for i, fixID := range state.fixIDs {
		if i > 0 {
			fixControlStrBuilder.WriteString(",")
		}
		fixControlStrBuilder.WriteString(fmt.Sprintf("%v:%v", fixID, state.fixValues[i]))
	}
	fixControlMap, _, err := fixcontrol.ParseToMap(fixControlStrBuilder.String())
	if err != nil {
		return nil, err
	}
	sctx.GetSessionVars().OptimizerFixControl = fixControlMap

	// construct the leading hint and add it into the current stmtNode
	if state.leading2[0] != nil && state.leading2[1] != nil {
		if sel, isSel := stmt.(*ast.SelectStmt); isSel {
			defer func(hintsLen int) {
				sel.TableHints = sel.TableHints[:hintsLen]
			}(len(sel.TableHints))
			leadingHint := &ast.TableOptimizerHint{
				HintName: ast.NewCIStr(hint.HintLeading),
				Tables: []ast.HintTable{
					{
						DBName:    ast.NewCIStr(state.leading2[0].schema),
						TableName: ast.NewCIStr(state.leading2[0].name),
					},
					{
						DBName:    ast.NewCIStr(state.leading2[1].schema),
						TableName: ast.NewCIStr(state.leading2[1].name),
					},
				},
			}
			sel.TableHints = append(sel.TableHints, leadingHint)
		}
	}

	planDigest, planHints, planText, err := GenBriefPlanWithSCtx(sctx, stmt)
	if err != nil {
		return nil, err
	}
	return &genedPlan{
		planDigest: planDigest,
		planText:   planText,
		planHints:  planHints,
	}, nil
}

// adjustVar returns the new value of the variable for plan generation.
//
//   - cost factors are multiplied by 5 per step, i.e. a growing penalty, until they reach 1e6;
//     from then on they are returned unchanged;
//   - ratio variables (range [0, 1], "<=0" means disabled) jump to 0.1 when disabled and then
//     grow by 0.1 per step, staying put once another step would exceed 1;
//   - the prefer-range-scan switch is flipped.
//
// Returning the input unchanged means the variable cannot be pushed any further.
func adjustVar(varName string, varVal any) (newVarVal any, err error) {
	switch varName {
	case vardef.TiDBOptIndexScanCostFactor, vardef.TiDBOptIndexReaderCostFactor, vardef.TiDBOptTableReaderCostFactor,
		vardef.TiDBOptTableFullScanCostFactor, vardef.TiDBOptTableRangeScanCostFactor, vardef.TiDBOptTableRowIDScanCostFactor,
		vardef.TiDBOptTableTiFlashScanCostFactor, vardef.TiDBOptIndexLookupCostFactor, vardef.TiDBOptIndexMergeCostFactor,
		vardef.TiDBOptSortCostFactor, vardef.TiDBOptTopNCostFactor, vardef.TiDBOptLimitCostFactor,
		vardef.TiDBOptStreamAggCostFactor, vardef.TiDBOptHashAggCostFactor, vardef.TiDBOptMergeJoinCostFactor,
		vardef.TiDBOptHashJoinCostFactor, vardef.TiDBOptIndexJoinCostFactor:
		cur := varVal.(float64)
		if cur < 1e6 {
			return cur * 5, nil // penalize: 5 times the current factor
		}
		return cur, nil // avoid too large penalty.
	case vardef.TiDBOptOrderingIdxSelRatio, vardef.TiDBOptRiskEqSkewRatio:
		cur := varVal.(float64)
		switch {
		case cur <= 0:
			return 0.1, nil // enable with the smallest step
		case cur+0.1 > 1:
			return cur, nil // one more step would leave [0, 1]
		default:
			return cur + 0.1, nil
		}
	case vardef.TiDBOptPreferRangeScan:
		flipped := !varVal.(bool)
		return flipped, nil
	default:
		return nil, fmt.Errorf("unsupported variable %s in plan generation", varName)
	}
}

// adjustFix returns the new value of the fix-control for plan generation.
//
//   - Fix44855 and Fix52869 are on/off switches and are flipped (anything other than OFF,
//     compared case-insensitively after trimming spaces, becomes OFF);
//   - Fix45132 holds an integer that is halved each step until it is <= 10, after which the
//     value is returned unchanged so the search does not revisit it.
//
// It returns an error for an unparsable Fix45132 value or an unsupported fix-control.
func adjustFix(fixID uint64, fixVal string) (newFixVal string, err error) {
	switch fixID {
	case fixcontrol.Fix44855, fixcontrol.Fix52869: // flip the switch
		if strings.ToUpper(strings.TrimSpace(fixVal)) == vardef.Off {
			return vardef.On, nil
		}
		return vardef.Off, nil
	case fixcontrol.Fix45132:
		// Trim like the switch branch above so surrounding spaces are not a parse error.
		num, parseErr := strconv.ParseInt(strings.TrimSpace(fixVal), 10, 64)
		if parseErr != nil {
			return "", fmt.Errorf("invalid value %q for fix-control %d: %w", fixVal, fixID, parseErr)
		}
		if num <= 10 {
			return fixVal, nil
		}
		// each step halves the value.
		return strconv.FormatInt(num/2, 10), nil
	default:
		return "", fmt.Errorf("unsupported fix-control %d in plan generation", fixID)
	}
}

func getStartState(vars []string, fixes []uint64) (*state, error) {
	// use the default values of these vars and fix-controls as the initial state.
	s := &state{varNames: vars, fixIDs: fixes}
	for _, varName := range vars {
		switch varName {
		case vardef.TiDBOptIndexScanCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptIndexScanCostFactor)
		case vardef.TiDBOptIndexReaderCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptIndexReaderCostFactor)
		case vardef.TiDBOptTableReaderCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptTableReaderCostFactor)
		case vardef.TiDBOptTableFullScanCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptTableFullScanCostFactor)
		case vardef.TiDBOptTableRangeScanCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptTableRangeScanCostFactor)
		case vardef.TiDBOptTableRowIDScanCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptTableRowIDScanCostFactor)
		case vardef.TiDBOptTableTiFlashScanCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptTableTiFlashScanCostFactor)
		case vardef.TiDBOptIndexLookupCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptIndexLookupCostFactor)
		case vardef.TiDBOptIndexMergeCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptIndexMergeCostFactor)
		case vardef.TiDBOptSortCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptSortCostFactor)
		case vardef.TiDBOptTopNCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptTopNCostFactor)
		case vardef.TiDBOptLimitCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptLimitCostFactor)
		case vardef.TiDBOptStreamAggCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptStreamAggCostFactor)
		case vardef.TiDBOptHashAggCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptHashAggCostFactor)
		case vardef.TiDBOptMergeJoinCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptMergeJoinCostFactor)
		case vardef.TiDBOptHashJoinCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptHashJoinCostFactor)
		case vardef.TiDBOptIndexJoinCostFactor:
			s.varValues = append(s.varValues, vardef.DefOptIndexJoinCostFactor)
		case vardef.TiDBOptOrderingIdxSelRatio:
			s.varValues = append(s.varValues, vardef.DefTiDBOptOrderingIdxSelRatio)
		case vardef.TiDBOptRiskEqSkewRatio:
			s.varValues = append(s.varValues, vardef.DefOptRiskEqSkewRatio)
		case vardef.TiDBOptPreferRangeScan:
			s.varValues = append(s.varValues, vardef.DefOptPreferRangeScan)
		default:
			return nil, fmt.Errorf("unsupported variable %s in plan generation", varName)
		}
	}

	for _, fixID := range fixes {
		switch fixID {
		case fixcontrol.Fix44855:
			s.fixValues = append(s.fixValues, "OFF")
		case fixcontrol.Fix45132:
			s.fixValues = append(s.fixValues, "1000")
		case fixcontrol.Fix52869:
			s.fixValues = append(s.fixValues, "OFF")
		default:
			return nil, fmt.Errorf("unsupported fix-control %d in plan generation", fixID)
		}
	}
	return s, nil
}

// tableNameExtractor is an ast.Visitor that collects the distinct *ast.TableName nodes it visits,
// keyed by "schema.name". Names without an explicit schema are attributed to defaultSchema.
type tableNameExtractor struct {
	defaultSchema string
	tableNames    map[string]*tableName // keyed by tableName.String()
}

// Enter implements ast.Visitor interface.
// It records table names (lower-cased) and never skips children.
func (e *tableNameExtractor) Enter(in ast.Node) (node ast.Node, skipChildren bool) {
	tn, isTable := in.(*ast.TableName)
	if !isTable {
		return in, false
	}
	schema := tn.Schema.L
	if schema == "" {
		schema = e.defaultSchema
	}
	t := &tableName{schema: schema, name: tn.Name.L}
	key := t.String()
	if _, exists := e.tableNames[key]; !exists {
		e.tableNames[key] = t
	}
	return in, false
}

// Leave implements ast.Visitor interface.
// It leaves the node unchanged and keeps walking.
func (*tableNameExtractor) Leave(in ast.Node) (node ast.Node, ok bool) {
	return in, true
}

// extractSelectTableNames returns the distinct table names referenced in the SELECT statement,
// sorted by their "schema.name" form so the order is deterministic. Unqualified names are
// attributed to defaultSchema. It returns nil when node is not an *ast.SelectStmt.
func extractSelectTableNames(defaultSchema string, node ast.StmtNode) []*tableName {
	sel, ok := node.(*ast.SelectStmt)
	if !ok {
		return nil // only support SELECT statement for now
	}
	collector := &tableNameExtractor{
		defaultSchema: defaultSchema,
		tableNames:    make(map[string]*tableName),
	}
	sel.Accept(collector)

	// Map keys are exactly each entry's String(), so sorting keys orders the tables identically.
	keys := make([]string, 0, len(collector.tableNames))
	for key := range collector.tableNames {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	result := make([]*tableName, len(keys))
	for i, key := range keys {
		result[i] = collector.tableNames[key]
	}
	return result
}
